# -*- coding:utf-8 -*-

import tensorflow as tf
from element.dnn_module import fc_layer
import tensorflow.contrib.slim as slim
from utils.log_utils import log_debug

"""
CNN模块
"""


def conv2d_def(x, out_channel=1, kernel_size=(1, 1), stride=1, padding='VALID', activation_fn=tf.nn.relu):
    """
    Plain (hand-built) 2-D convolution layer: conv + bias + activation.

    Also emits a `tf.summary.image` where each output channel is shown as a
    separate single-channel image stacked along the batch axis.
    :param x: input tensor, NHWC layout (channel count must be statically known)
    :param out_channel: number of output channels
    :param kernel_size: (height, width) of the convolution kernel
    :param stride: stride applied to both spatial dimensions
    :param padding: 'VALID' or 'SAME'
    :param activation_fn: activation applied after the bias add
    :return: activated convolution output
    """
    num_in = x.shape.as_list()[-1]
    kernel_shape = [kernel_size[0], kernel_size[1], num_in, out_channel]
    weights = tf.Variable(tf.truncated_normal(kernel_shape, stddev=0.1))  # init stddev 0.1
    biases = tf.Variable(tf.constant(0.1, shape=[out_channel]))
    out = activation_fn(
        tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding=padding) + biases)

    # Split per channel (axis=3), then stack the channels along the batch
    # axis so a multi-channel activation can be rendered as images.
    channels = tf.split(out, num_or_size_splits=weights.shape.as_list()[3], axis=3)
    tf.summary.image('out', tf.concat(channels, axis=0), max_outputs=1)

    msg = 'conv2d--in:{}, kernel:{}@{}, stride:{}, out:{}'.format(
        x.shape.as_list(), out_channel, kernel_size, stride, out.shape.as_list())
    log_debug(msg, fore='b')
    return out


def conv2d_slim(x, out_channel=1, kernel_size=(1, 1), stride=1, padding='VALID', activation_fn=tf.nn.relu):
    """
    2-D convolution built on `slim.conv2d`, with L2 weight regularization.

    :param x: input tensor, NHWC layout
    :param out_channel: number of output channels
    :param kernel_size: (height, width) of the convolution kernel
    :param stride: stride applied to both spatial dimensions
    :param padding: 'VALID' or 'SAME'
    :param activation_fn: activation applied by slim after the conv
    :return: convolution output
    """
    result = slim.conv2d(
        x, out_channel, kernel_size,
        stride=stride,
        padding=padding,
        activation_fn=activation_fn,
        weights_regularizer=slim.l2_regularizer(0.0001),  # L2 decay on kernels
    )
    msg = 'conv2d--in:{}, kernel:{}@{}, stride:{}, out:{}'.format(
        x.shape.as_list(), out_channel, kernel_size, stride, result.shape.as_list())
    log_debug(msg, fore='b')
    return result


def max_pool_def(x, ksize, strides, padding='SAME'):
    """
    Max-pooling layer (thin wrapper over `tf.nn.max_pool` with debug logging).

    :param x: input tensor, NHWC layout
    :param ksize: pooling window, 4-element list [1, h, w, 1]
    :param strides: pooling strides, 4-element list [1, sh, sw, 1]
    :param padding: 'SAME' or 'VALID'
    :return: pooled tensor
    """
    pooled = tf.nn.max_pool(x, ksize=ksize, strides=strides, padding=padding)
    msg = 'maxpool--in:{}, pool:{}, out:{}'.format(
        x.shape.as_list(), ksize, pooled.shape.as_list())
    log_debug(msg, fore='b')
    return pooled


def fc_layer_from_conv2d(x, out_size, activation_func=None):
    """
    Fully-connected layer that follows a convolutional layer.

    Flattens every non-batch dimension of `x` and feeds the result to
    `fc_layer`. Generalized from the original hard-coded rank-4 (NHWC)
    indexing: it now works for any input rank >= 2 while producing the
    identical reshape for 4-D conv output.
    :param x: input tensor; all dims except the batch dim must be statically
        known (otherwise `as_list()` yields None and the size product fails)
    :param out_size: output width of the fully-connected layer
    :param activation_func: optional activation forwarded to fc_layer
    :return: output of fc_layer
    """
    flat_size = 1
    for dim in x.shape.as_list()[1:]:  # product of all non-batch dims
        flat_size *= dim
    x = tf.reshape(x, [-1, flat_size])  # -1 keeps the (possibly dynamic) batch dim
    return fc_layer(x, out_size, activation_func=activation_func)


def inception_v1():
    """Placeholder for an Inception-v1 (GoogLeNet) module; not implemented yet."""
    pass


def inception_v3():
    """Placeholder for an Inception-v3 module; not implemented yet."""
    pass
